label smooth的pytorch实现以及其公式推导(虽然短但是细)

您所在的位置:网站首页 label smoothing作用 label smooth的pytorch实现以及其公式推导(虽然短但是细)

label smooth的pytorch实现以及其公式推导(虽然短但是细)

2023-12-14 04:39| 来源: 网络整理| 查看: 265

标签平滑:label smooth

标签平滑是一种正则化手段,目的为了解决onehot编码的缺陷,减少过拟合问题。在各种竞赛中广泛使用,涨点神器。

假设: 预测的结果为 y p r e d y_{pred} ypred​, 真实结果为 y t r u e y_{true} ytrue​,类别数量为 ∗ ∗ N **N ∗∗N,标签平滑因子为 ϵ , 即 e p s i l o n \epsilon,即epsilon ϵ,即epsilon

标签平滑即在 y t r u e y_{true} ytrue​的one-hot编码中进行处理。 y n e w t r u e = ( 1 − ϵ ) ∗ y t r u e + ϵ / N { 1 , 0 , 0 } = > { ϵ = 0.1 } = > { 0.933 , 0.033 , 0.033 } { 0 , 1 , 0 } = > { ϵ = 0.5 } = > { 0.16 , 0.66 , 0.16 } y_{new_true} = (1 - \epsilon) * y_{true} + \epsilon / N \\ \{1, 0 ,0\} => \{\epsilon = 0.1\} = > \{0.933,0.033,0.033\} \\ \{0, 1 ,0\} => \{\epsilon = 0.5\} = > \{0.16,0.66,0.16\} ynewt​rue​=(1−ϵ)∗ytrue​+ϵ/N{1,0,0}=>{ϵ=0.1}=>{0.933,0.033,0.033}{0,1,0}=>{ϵ=0.5}=>{0.16,0.66,0.16} 在多分类中,往往采用交叉熵作为损失函数,如何将标签平滑和交叉熵进行结合,下面我们进行推导: L c r o s s _ e n t r o y = − ∑ y t r u e   l o g   y p r e d L c r o s s _ e n t r o y _ l a b e l s m o o t h = − ∑ ( ( 1 − ϵ ) ∗ y t r u e + ϵ / N )   l o g   y p r e d = − ∑ ( 1 − ϵ ) ∗ y t r u e   l o g   y p r e d − ∑ ϵ / N   l o g   y p r e d = − ( 1 − ϵ ) ∑ y t r u e   l o g   y p r e d − ϵ / N ∑   l o g   y p r e d = ( 1 − ϵ ) ∗ L c r o s s _ e n t r o y − ϵ / N ∑   l o g   y p r e d L_{cross\_entroy} = -\sum y_{true}~log~y_{pred} \\ L_{cross\_entroy\_labelsmooth} = -\sum ((1 - \epsilon) * y_{true} + \epsilon / N)~log~y_{pred} \\ = -\sum (1 - \epsilon) * y_{true}~log~y_{pred} - \sum \epsilon / N~log~y_{pred}\\ =-(1 - \epsilon) \sum y_{true}~log~y_{pred} - \epsilon / N\sum~log~y_{pred} \\ = (1 - \epsilon) * L_{cross\_entroy} - \epsilon / N\sum~log~y_{pred} Lcross_entroy​=−∑ytrue​ log ypred​Lcross_entroy_labelsmooth​=−∑((1−ϵ)∗ytrue​+ϵ/N) log ypred​=−∑(1−ϵ)∗ytrue​ log ypred​−∑ϵ/N log ypred​=−(1−ϵ)∑ytrue​ log ypred​−ϵ/N∑ log ypred​=(1−ϵ)∗Lcross_entroy​−ϵ/N∑ log ypred​ 根据公式的最后一行,我们知道使用标签平滑的交叉熵损失,只需要在原来的损失函数上乘上一个因子 ( 1 − ϵ ) (1-\epsilon) (1−ϵ),并减去因子 ( ϵ / N ) (\epsilon /N) (ϵ/N)和预测结果对数之和的乘积。

代码:https://github.com/lonePatient/label_smoothing_pytorch/blob/master/lsr.py(来自网上)

import torch.nn as nn import torch.nn.functional as F ## eps 表示标签平滑因子 class LabelSmoothingCrossEntropy(nn.Module): def __init__(self, eps=0.1, reduction='mean'): super(LabelSmoothingCrossEntropy, self).__init__() self.eps = eps self.reduction = reduction def forward(self, output, target): c = output.size()[-1] log_preds = F.log_softmax(output, dim=-1) if self.reduction=='sum': loss = -log_preds.sum() else: loss = -log_preds.sum(dim=-1) if self.reduction=='mean': loss = loss.mean() return loss*self.eps/c + (1-self.eps) * F.nll_loss(log_preds, target, reduction=self.reduction)


【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3